{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "lovely-budapest",
   "metadata": {},
   "source": [
    "# This is a notebook that shows how to produce Grad-CAM visualizations for ALBEF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "czech-surprise",
   "metadata": {},
   "source": [
    "# 1. Set the paths for model checkpoint and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "institutional-sarah",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file that section 5 reads with torch.load and feeds to load_state_dict.\n",
    "model_path = '../VL/Example/refcoco.pth'\n",
    "# JSON file passed to BertConfig.from_json_file when the text encoder is built.\n",
    "bert_config_path = 'configs/config_bert.json'\n",
    "# When True, the model and its inputs are moved to the GPU with .cuda().\n",
    "use_cuda = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "lovely-passage",
   "metadata": {},
   "source": [
    "# 2. Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "documented-symbol",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from models.vit import VisionTransformer\n",
    "from models.xbert import BertConfig, BertModel\n",
    "from models.tokenization_bert import BertTokenizer\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import transforms\n",
    "\n",
    "import json\n",
    "\n",
    "class VL_Transformer_ITM(nn.Module):\n",
    "    \"\"\"Image-text matching (ITM) model: a ViT image encoder plus a cross-attending BERT.\n",
    "\n",
    "    The text encoder receives the image tokens as ``encoder_hidden_states``; the\n",
    "    hidden state at text position 0 is projected by a linear head to two logits.\n",
    "\n",
    "    Args:\n",
    "        text_encoder: name or path handed to ``BertModel.from_pretrained``.\n",
    "        config_bert: path of the JSON file read by ``BertConfig.from_json_file``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 text_encoder = None,\n",
    "                 config_bert = ''\n",
    "                 ):\n",
    "        super().__init__()\n",
    "\n",
    "        bert_config = BertConfig.from_json_file(config_bert)\n",
    "\n",
    "        # Attribute names are part of the contract: they determine the\n",
    "        # state_dict keys that the checkpoint is matched against.\n",
    "        self.visual_encoder = VisionTransformer(\n",
    "            img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,\n",
    "            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))\n",
    "\n",
    "        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)\n",
    "\n",
    "        # Linear head mapping the 768-d fused embedding to two logits.\n",
    "        self.itm_head = nn.Linear(768, 2)\n",
    "\n",
    "    def forward(self, image, text):\n",
    "        \"\"\"Return the ITM logits for a batch of image/text pairs.\n",
    "\n",
    "        Args:\n",
    "            image: image batch fed to the visual encoder.\n",
    "            text: tokenizer output exposing ``input_ids`` and ``attention_mask``.\n",
    "\n",
    "        Returns:\n",
    "            The output of ``itm_head``; its last dimension has size 2.\n",
    "        \"\"\"\n",
    "        image_embeds = self.visual_encoder(image)\n",
    "\n",
    "        # Every image token may be attended to. The mask is allocated directly\n",
    "        # on the image's device instead of being built on the CPU and copied.\n",
    "        image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=image.device)\n",
    "\n",
    "        output = self.text_encoder(text.input_ids,\n",
    "                                attention_mask = text.attention_mask,\n",
    "                                encoder_hidden_states = image_embeds,\n",
    "                                encoder_attention_mask = image_atts,\n",
    "                                return_dict = True,\n",
    "                               )\n",
    "\n",
    "        # Hidden state at text position 0 serves as the joint embedding.\n",
    "        vl_embeddings = output.last_hidden_state[:,0,:]\n",
    "        vl_output = self.itm_head(vl_embeddings)\n",
    "        return vl_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "renewable-eight",
   "metadata": {},
   "source": [
    "# 3. Text Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "optional-brooklyn",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def pre_caption(caption,max_words=30):\n",
    "    \"\"\"Normalise a caption and cap it at ``max_words`` space-separated words.\n",
    "\n",
    "    The caption is lower-cased, a fixed set of punctuation characters is removed,\n",
    "    hyphens and slashes become spaces, runs of two or more whitespace characters\n",
    "    collapse to a single space, and trailing newlines / surrounding spaces are\n",
    "    stripped before truncation.\n",
    "    \"\"\"\n",
    "    lowered = caption.lower()\n",
    "    no_punct = re.sub(r\"([,.'!?\\\"()*#:;~])\", '', lowered)\n",
    "    spaced = no_punct.replace('-', ' ').replace('/', ' ')\n",
    "    collapsed = re.sub(r\"\\s{2,}\", ' ', spaced)\n",
    "    cleaned = collapsed.rstrip('\\n').strip(' ')\n",
    "\n",
    "    # Captions already within the limit are returned untouched.\n",
    "    tokens = cleaned.split(' ')\n",
    "    if not len(tokens) > max_words:\n",
    "        return cleaned\n",
    "    return ' '.join(tokens[:max_words])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "based-roads",
   "metadata": {},
   "source": [
    "# 4. Image Preprocessing and Postprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "subsequent-flesh",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "from skimage import transform as skimage_transform\n",
    "from scipy.ndimage import filters\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "def getAttMap(img, attMap, blur = True, overlap = True):\n",
    "    \"\"\"Turn a low-resolution attention map into a heat-map, optionally blended over ``img``.\n",
    "\n",
    "    Args:\n",
    "        img: image array; ``img.shape[:2]`` is the target (height, width) and the\n",
    "            array is blended with a 3-channel colour map, so (H, W, 3) is expected.\n",
    "        attMap: 2-D attention map. It is never modified in place.\n",
    "        blur: if True, Gaussian-smooth the resized map (sigma is 2% of the longer\n",
    "            image side) and rescale it.\n",
    "        overlap: if True, return the 'jet'-coloured map blended over ``img``;\n",
    "            otherwise return the resized (and possibly blurred) scalar map.\n",
    "\n",
    "    Returns:\n",
    "        The blended (H, W, 3) array when ``overlap`` is True, else the (H, W) map.\n",
    "    \"\"\"\n",
    "    # Out-of-place arithmetic so the caller's array / tensor is left untouched.\n",
    "    attMap = attMap - attMap.min()\n",
    "    peak = attMap.max()\n",
    "    if peak > 0:\n",
    "        attMap = attMap / peak\n",
    "    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order = 3, mode = 'constant')\n",
    "    if blur:\n",
    "        # Imported from scipy.ndimage directly: the scipy.ndimage.filters\n",
    "        # namespace is deprecated.\n",
    "        from scipy.ndimage import gaussian_filter\n",
    "        attMap = gaussian_filter(attMap, 0.02*max(img.shape[:2]))\n",
    "        attMap -= attMap.min()\n",
    "        peak = attMap.max()\n",
    "        # A flat map has peak 0; skip the division instead of producing NaNs.\n",
    "        if peak > 0:\n",
    "            attMap /= peak\n",
    "    cmap = plt.get_cmap('jet')\n",
    "    # The colour map returns RGBA; keep the RGB channels only.\n",
    "    attMapV = cmap(attMap)[..., :3]\n",
    "    if overlap:\n",
    "        # Cubic interpolation can overshoot [0, 1] when blur is off; clip so the\n",
    "        # fractional power stays real and the blend weights stay within [0, 1].\n",
    "        weight = (np.clip(attMap, 0, 1) ** 0.7)[..., np.newaxis]\n",
    "        attMap = (1 - weight) * img + weight * attMapV\n",
    "    return attMap\n",
    "\n",
    "\n",
    "# Per-channel mean / std, applied after ToTensor in the pipeline below.\n",
    "# NOTE(review): these look like the statistics published with CLIP; presumably they\n",
    "# match what the checkpoint was trained with - confirm.\n",
    "normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "\n",
    "# Resize to 384x384 (the img_size the VisionTransformer is built with), convert the\n",
    "# PIL image to a tensor, then normalize it.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((384,384),interpolation=Image.BICUBIC),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])     "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "occasional-trace",
   "metadata": {},
   "source": [
    "# 5. Load model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "qualified-sleep",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.pooler.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.value.bias', 
'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.self.query.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "\n",
    "model = VL_Transformer_ITM(text_encoder='bert-base-uncased', config_bert=bert_config_path)\n",
    "\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(model_path, map_location='cpu')\n",
    "msg = model.load_state_dict(checkpoint,strict=False)\n",
    "# strict=False never raises on a key mismatch, so report the outcome explicitly.\n",
    "# Missing keys are parameters that keep the values they were initialised with\n",
    "# instead of receiving checkpoint weights.\n",
    "print(f'missing keys: {len(msg.missing_keys)}, unexpected keys: {len(msg.unexpected_keys)}')\n",
    "if msg.missing_keys:\n",
    "    print('missing:', msg.missing_keys)\n",
    "model.eval()\n",
    "\n",
    "# Index of the text-encoder layer whose cross-attention is recorded for Grad-CAM.\n",
    "block_num = 8\n",
    "\n",
    "model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True\n",
    "\n",
    "if use_cuda:\n",
    "    model.cuda() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "apparent-captain",
   "metadata": {},
   "source": [
    "# 6. Load Image and Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "finite-angle",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'examples/image0.jpg'\n",
    "image_pil = Image.open(image_path).convert('RGB')   \n",
    "# Apply the preprocessing pipeline and add a leading batch dimension of size 1.\n",
    "image = transform(image_pil).unsqueeze(0)  \n",
    "\n",
    "caption = 'the woman is working on her computer at the desk'\n",
    "# Normalise the caption with pre_caption, then tokenise it into PyTorch tensors.\n",
    "text = pre_caption(caption)\n",
    "text_input = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# Move both inputs to the GPU when CUDA is enabled; the text follows the image's device.\n",
    "if use_cuda:\n",
    "    image = image.cuda()\n",
    "    text_input = text_input.to(image.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gorgeous-matrix",
   "metadata": {},
   "source": [
    "# 7. Compute GradCAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "driven-termination",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass; the scalar to differentiate is the sum of the logit at index 1.\n",
    "# NOTE(review): index 1 is presumably the 'match' class of the ITM head - confirm.\n",
    "output = model(image, text_input)\n",
    "loss = output[:,1].sum()\n",
    "\n",
    "# Clear stale gradients, then back-propagate so the recorded attention receives gradients.\n",
    "model.zero_grad()\n",
    "loss.backward()    \n",
    "\n",
    "# Everything below is post-processing of stored tensors; no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    # Reshape the text attention mask to (batch, 1, tokens, 1, 1) so it broadcasts\n",
    "    # against the 5-D tensors built below.\n",
    "    mask = text_input.attention_mask.view(text_input.attention_mask.size(0),1,-1,1,1)\n",
    "\n",
    "    # Project-specific accessors on the cross-attention module of layer block_num.\n",
    "    # NOTE(review): presumably populated because save_attention was set on this\n",
    "    # module in section 5 - confirm in models/xbert.\n",
    "    grads=model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attn_gradients()\n",
    "    cams=model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.get_attention_map()\n",
    "\n",
    "    # Drop index 0 of the last axis and fold the rest into a 24x24 grid\n",
    "    # (384 / 16 = 24 patches per side). NOTE(review): index 0 is presumably the\n",
    "    # image [CLS] token and 12 the number of attention heads - confirm.\n",
    "    cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 24, 24) * mask\n",
    "    # clamp(0) keeps only the non-negative gradients.\n",
    "    grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 24, 24) * mask\n",
    "\n",
    "    # Weight the attention by its gradient, take the first sample and average over\n",
    "    # the size-12 axis, leaving one 24x24 map per token.\n",
    "    gradcam = cams * grads\n",
    "    gradcam = gradcam[0].mean(0).cpu().detach()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abroad-northern",
   "metadata": {},
   "source": [
    "# 8. Visualize GradCam for each word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fourth-cache",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_ids = text_input.input_ids[0]\n",
    "# One row for the plain image plus one row per token after the first.\n",
    "num_image = len(token_ids)\n",
    "fig, ax = plt.subplots(num_image, 1, figsize=(15,5*num_image))\n",
    "\n",
    "# cv2.imread yields BGR; reverse the channel axis and scale to floats in [0, 1].\n",
    "bgr_image = cv2.imread(image_path)\n",
    "rgb_image = np.float32(bgr_image[:, :, ::-1]) / 255\n",
    "\n",
    "\n",
    "def _draw_panel(axis, picture, label):\n",
    "    \"\"\"Show ``picture`` on ``axis`` without ticks, using ``label`` as the x-label.\"\"\"\n",
    "    axis.imshow(picture)\n",
    "    axis.set_yticks([])\n",
    "    axis.set_xticks([])\n",
    "    axis.set_xlabel(label)\n",
    "\n",
    "\n",
    "_draw_panel(ax[0], rgb_image, \"Image\")\n",
    "\n",
    "# Row k shows the Grad-CAM map of token k; token 0 is skipped.\n",
    "for row, token_id in enumerate(token_ids[1:], start=1):\n",
    "    word = tokenizer.decode([token_id])\n",
    "    gradcam_image = getAttMap(rgb_image, gradcam[row])\n",
    "    _draw_panel(ax[row], gradcam_image, word)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
